import torch

# Sample a random (3, 5, 2) tensor, then allocate a zero tensor that keeps
# every leading dimension of `q` but collapses the final axis to size 1.
# Building the size tuple up front gives the same (3, 5, 1) result as
# allocating zeros(q.shape[:-1]) and unsqueezing the last axis afterwards.
q = torch.rand(3, 5, 2)
l = torch.zeros((*q.shape[:-1], 1))
print(l.shape)  # expected: torch.Size([3, 5, 1])




